"""@plugin BEASTVulnerabilityTesterPlugin
This file contains the classes of a plugin that tests server(s) for the BEAST vulnerability (CVE-2011-3389).

@author: Bc. Pavel Soukup
"""

from operator import attrgetter
from xml.etree.ElementTree import Element

from nassl.ssl_client import SslClient
from sets import Set

from sslyze.plugins.common_new_plugin_info import PROTOCOL_VERSION, OPENSSL_TO_RFC_NAMES_MAPPING, AcceptCipher, RejectCipher
from sslyze.plugins import plugin_base
from sslyze.plugins.plugin_base import PluginResult
from sslyze.utils.ssl_connection import SSLHandshakeRejected, SSLConnection
from sslyze.utils.thread_pool import ThreadPool


class BEASTVulnerabilityTesterPlugin(plugin_base.PluginBase):
    """Plugin that tests a server for the BEAST vulnerability (CVE-2011-3389).

    The plugin checks which of the SSLv3 and TLSv1 protocols the server
    accepts, tries every cipher suite offered by the local SSL client for each
    accepted protocol, and reports the server as vulnerable when at least one
    accepted cipher suite uses CBC mode.
    """

    interface = plugin_base.PluginInterface(
        "BEASTVulnerabilityTesterPlugin",
        "Scans the server(s) and checks if requirements for BEAST attack are satisfied.")
    interface.add_command(command="beast", help="Tests server for BEAST vulnerability.")

    # Upper bound on the number of worker threads used to test cipher suites.
    MAX_THREADS = 15

    def process_task(self, server_connectivity_info, plugin_command, options_dict=None):
        """Run the BEAST test against one server.

        Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            plugin_command (str): command that triggered the plugin; passed through to the result
            options_dict (dict): optional plugin options; a truthy 'verbose' entry
                prints every tested cipher suite

        Returns:
            BEASTVulnerabilityTesterResult: vulnerable cipher suites, the verdict and
            the list of supported protocols.

        Raises:
            ValueError: if a job returns something other than AcceptCipher or RejectCipher.
            Exception: the first exception reported by a worker thread.
        """
        verbose_mode = options_dict.get('verbose', False) if options_dict else False

        # A tuple (not a set) keeps the order of the reported protocols stable.
        test_protocols = ('SSLv3', 'TLSv1')
        thread_pool = ThreadPool()
        support_protocol_list = []
        nb_jobs = 0
        for proto in test_protocols:
            ssl_version = PROTOCOL_VERSION[proto]
            if not self.test_protocol_support(ssl_version, server_connectivity_info):
                continue
            ssl_client = SslClient(ssl_version=ssl_version)
            # NOTE(review): the protocol name is passed as the cipher string,
            # presumably relying on OpenSSL's 'SSLv3'/'TLSv1' cipher-string
            # keywords - confirm.
            ssl_client.set_cipher_list(proto)
            for cipher in ssl_client.get_cipher_list():
                thread_pool.add_job(
                    (self._test_ciphersuite, (server_connectivity_info, ssl_version, cipher)))
                nb_jobs += 1
            support_protocol_list.append(proto)

        # Size the pool from the total number of queued jobs, not from the
        # cipher list of whichever protocol happened to be processed last.
        thread_pool.start(nb_threads=min(nb_jobs, self.MAX_THREADS))

        accept_ciphers = []
        if verbose_mode:
            print('  VERBOSE MODE PRINT')
            print('  ------------------')

        for (_, cipher_result) in thread_pool.get_result():
            if isinstance(cipher_result, AcceptCipher):
                accept_ciphers.append(cipher_result)
            elif not isinstance(cipher_result, RejectCipher):
                raise ValueError("Unexpected result")
            if verbose_mode:
                cipher_result.print_cipher()

        if verbose_mode:
            print('  ----------------------')
            print('  END VERBOSE MODE PRINT')
            print('  ----------------------')

        # Drain every reported error and join the pool before re-raising, so
        # that a failing job does not skip the pool clean-up.
        errors = [exception for (_, exception) in thread_pool.get_error()]
        thread_pool.join()
        if errors:
            raise errors[0]

        support_vulnerable_ciphers_set = self.get_vulnerable_ciphers(accept_ciphers)
        is_vulnerable = bool(support_vulnerable_ciphers_set)

        return BEASTVulnerabilityTesterResult(server_connectivity_info, plugin_command, options_dict, support_vulnerable_ciphers_set, is_vulnerable, support_protocol_list)

    def _test_ciphersuite(self, server_connectivity_info, ssl_version, cipher):
        """Check whether the server accepts one cipher suite. Used as the thread-pool job.

        Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            ssl_version: value from PROTOCOL_VERSION selecting the SSL/TLS protocol used to connect
            cipher (str): OpenSSL name of the cipher suite to test

        Returns:
            AcceptCipher if the handshake succeeds, RejectCipher (carrying the
            rejection message) if the server rejects the handshake.
        """
        # Report the RFC name when a mapping exists, otherwise the OpenSSL name.
        cipher_name = OPENSSL_TO_RFC_NAMES_MAPPING[ssl_version].get(cipher, cipher)
        ssl_conn = server_connectivity_info.get_preconfigured_ssl_connection(override_ssl_version=ssl_version)
        ssl_conn.set_cipher_list(cipher)
        try:
            ssl_conn.connect()
        except SSLHandshakeRejected as e:
            cipher_result = RejectCipher(cipher_name, str(e))
        else:
            cipher_result = AcceptCipher(cipher_name)
        finally:
            ssl_conn.close()
        return cipher_result

    def test_protocol_support(self, ssl_version, server_info):
        """Tell whether the server accepts a handshake with the given protocol version.

        Args:
            ssl_version: value from PROTOCOL_VERSION selecting the SSL/TLS protocol to test
            server_info (ServerConnectivityInfo): contains information for connection on server

        Returns:
            bool: True if the handshake succeeds, False if it is rejected.
        """
        ssl_conn = server_info.get_preconfigured_ssl_connection(override_ssl_version=ssl_version)
        try:
            ssl_conn.connect()
        except SSLHandshakeRejected:
            support = False
        else:
            support = True
        finally:
            ssl_conn.close()
        return support

    def get_vulnerable_ciphers(self, accept_ciphers):
        """Return the set of names of accepted cipher suites that use CBC mode.

        Args:
            accept_ciphers (list): AcceptCipher objects for the cipher suites supported by the server

        Returns:
            set: RFC names of the accepted cipher suites for which use_CBC_mode() is true.
        """
        return {cipher._cipher_rfc_name for cipher in accept_ciphers if cipher.use_CBC_mode()}

class BEASTVulnerabilityTesterResult(PluginResult):
    """Result of the BEAST (CVE-2011-3389) test made by BEASTVulnerabilityTesterPlugin.

    It can be rendered as a list of text lines (as_text) or as an XML element
    (as_xml).
    """

    COMMAND_TITLE = 'Vulnerability CVE-2011-3389'
    PROTOCOL_LINE_FORMAT = '      Support protocols: {protocols:<20}'.format
    CIPHER_LIST_TITLE_FORMAT = '      {section_title:<32}'.format
    CIPHER_LINE_FORMAT = u'        {cipher_name:<50}'.format

    def __init__(self, server_info, plugin_command, plugin_options, support_vulnerable_ciphers, is_vulnerable, support_protocols):
        """Store the outcome of the test.

        Args:
            server_info: forwarded to PluginResult
            plugin_command: forwarded to PluginResult; also used as the XML root tag
            plugin_options: forwarded to PluginResult
            support_vulnerable_ciphers: collection of names of the vulnerable cipher suites accepted by the server
            is_vulnerable (bool): verdict of the test
            support_protocols (list): names of the tested protocols accepted by the server
        """
        super(BEASTVulnerabilityTesterResult, self).__init__(server_info, plugin_command, plugin_options)
        self.support_vulnerable_ciphers = support_vulnerable_ciphers
        self.is_vulnerable = is_vulnerable
        self.support_protocols = support_protocols

    def as_text(self):
        """Return the result as a list of text lines.

        The lines are: title, verdict, supported protocols, and the vulnerable
        cipher suites (sorted by name so the output is stable between runs).
        """
        beast_txt = 'VULNERABLE - server is vulnerable to BEAST attack' \
            if self.is_vulnerable \
            else 'OK - Not vulnerable to BEAST attack'

        txt_output = [
            self.PLUGIN_TITLE_FORMAT(self.COMMAND_TITLE),
            self.FIELD_FORMAT("", beast_txt),
        ]

        # Fall back to an explanatory message when no tested protocol is supported.
        protocol_line = ', '.join(self.support_protocols) \
            or 'Both SSL 3.0 and TLS 1.0 are not support'
        txt_output.append(self.PROTOCOL_LINE_FORMAT(protocols=protocol_line))

        txt_output.append(self.CIPHER_LIST_TITLE_FORMAT(section_title='Vulnerable cipher/ciphers:'))
        if self.support_vulnerable_ciphers:
            for cipher_name in sorted(self.support_vulnerable_ciphers):
                txt_output.append(self.CIPHER_LINE_FORMAT(cipher_name=cipher_name))
        else:
            txt_output.append(self.CIPHER_LINE_FORMAT(cipher_name='No vulnerable ciphers are supported'))
        return txt_output

    def as_xml(self):
        """Return the result as an XML element.

        The root element is named after the plugin command. It always holds a
        'vulnerable' child; 'supportProtocols' and 'vulnerableCipherSuites'
        children are added only when they would not be empty. Cipher suites
        are sorted by name so the output is stable between runs.
        """
        xml_output = Element(self.plugin_command, title=self.COMMAND_TITLE)
        xml_output.append(Element('vulnerable', isVulnerable=str(self.is_vulnerable)))

        if self.support_protocols:
            protocol_xml = Element('supportProtocols')
            for protocol in self.support_protocols:
                protocol_xml.append(Element('protocol', name=protocol))
            xml_output.append(protocol_xml)

        if self.support_vulnerable_ciphers:
            cipher_xml = Element('vulnerableCipherSuites')
            for cipher_name in sorted(self.support_vulnerable_ciphers):
                cipher_xml.append(Element('cipherSuite', name=cipher_name))
            xml_output.append(cipher_xml)

        return xml_output